{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Trax Demo",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ySEmBgmqMSIJ",
        "colab_type": "text"
      },
      "source": [
        "##### Copyright 2019 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o4WGihMLneYq",
        "colab_type": "text"
      },
      "source": [
        "# Trax: Train Models in JAX\n",
        "\n",
        "[JAX](https://github.com/google/jax) lets you write [NumPy](https://www.numpy.org/) code and run it fast on accelerators.\n",
        "\n",
        "This makes ML research more *fun* and *clear*, so we built [Trax](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax), a library of models in JAX.\n",
        "\n",
        "In this demo we show how to:\n",
        "* Train a Trax model on a toy copy problem.\n",
        "* Decode from a pre-trained [Transformer](https://arxiv.org/abs/1706.03762) language model.\n",
        "* Define [Transformer](https://arxiv.org/abs/1706.03762) from scratch in Trax.\n",
        "* Do research in Trax: play with hard attention to see how it impacts training and results.\n",
        "\n",
        "We would like your feedback!\n",
        "* What are the parts you like or dislike in JAX and Trax?\n",
        "* Will you start doing your research in Trax? If not, why? What would change your mind?\n",
        "* What should we focus on? Speed, cleanliness, memory use?\n",
        "* If you cannot tell us in person, please add your feedback on [this github issue](https://github.com/tensorflow/tensor2tensor/issues/1478).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8YQw0hySTVlK",
        "colab_type": "text"
      },
      "source": [
        "## Installs\n",
        "\n",
        "We install JAX and Tensor2Tensor (which includes Trax), then download a pre-trained model and vocab file."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vAWJVzYRnbDU",
        "colab_type": "code",
        "outputId": "6cdeff6f-3fc9-406f-feaf-fd1f8d9de775",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 578
        }
      },
      "source": [
        "# Install JAX for GPU and Tensor2Tensor.\n",
        "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.14-cp36-none-linux_x86_64.whl\n",
        "!pip install --upgrade -q jax==0.1.27\n",
        "!pip install --upgrade -q tensor2tensor==1.13.4\n",
        "# Grab language-model checkpoint and vocab file.\n",
        "!rm -f model.pkl\n",
        "!wget https://storage.googleapis.com/traxdemo/model.pkl\n",
        "!wget https://storage.googleapis.com/traxdemo/vocab.lm1b.en.32768\n",
        "# Show GPU type.\n",
        "!nvidia-smi -L"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  Building wheel for jax (setup.py) ... done\n",
            "  Building wheel for opt-einsum (setup.py) ... done\n",
            "  Building wheel for pypng (setup.py) ... done\n",
            "--2019-05-14 22:57:21--  https://storage.googleapis.com/traxdemo/model.pkl\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 209.85.234.128, 2607:f8b0:4001:c12::80\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|209.85.234.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 211170062 (201M) [application/octet-stream]\n",
            "Saving to: ‘model.pkl’\n",
            "\n",
            "model.pkl           100%[===================>] 201.39M   101MB/s    in 2.0s    \n",
            "\n",
            "2019-05-14 22:57:23 (101 MB/s) - ‘model.pkl’ saved [211170062/211170062]\n",
            "\n",
            "--2019-05-14 22:57:23--  https://storage.googleapis.com/traxdemo/vocab.lm1b.en.32768\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.183.128, 2607:f8b0:4001:c07::80\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.183.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 297760 (291K) [application/octet-stream]\n",
            "Saving to: ‘vocab.lm1b.en.32768’\n",
            "\n",
            "vocab.lm1b.en.32768 100%[===================>] 290.78K  --.-KB/s    in 0.007s  \n",
            "\n",
            "2019-05-14 22:57:24 (40.8 MB/s) - ‘vocab.lm1b.en.32768’ saved [297760/297760]\n",
            "\n",
            "GPU 0: Tesla T4 (UUID: GPU-1959cc75-52ab-cf03-e5fa-36aee0d59bc5)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vvFrqacVS6B6",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dYq8J8uBn9ZC",
        "colab_type": "code",
        "outputId": "db8ca8de-164c-4355-8abb-a493e7f9f393",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 136
        }
      },
      "source": [
        "from six.moves import cPickle\n",
        "import os\n",
        "import datetime\n",
        "import random\n",
        "\n",
        "import numpy as onp\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "from jax.ops import index, index_update\n",
        "\n",
        "from tensor2tensor.trax import trax\n",
        "from tensor2tensor.trax import layers as tl\n",
        "from tensor2tensor.trax import inputs as trax_input\n",
        "from tensor2tensor.trax import models as trax_models\n",
        "from tensor2tensor.trax import optimizers as trax_optimizers\n",
        "from tensor2tensor.trax import backend\n",
        "from tensor2tensor.trax.backend import numpy as np\n",
        "from tensor2tensor.trax.backend import random as trax_random"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
            "For more information, please see:\n",
            "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
            "  * https://github.com/tensorflow/addons\n",
            "If you depend on functionality not listed there, please file an issue.\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zR6RVHx4lPzA",
        "colab_type": "text"
      },
      "source": [
        "# Toy Copy Problem\n",
        "\n",
        "Here we define batched random integer inputs for a trivial sequence-copy learning task."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wGmWmpIslQYv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "VOCAB_SIZE = 128\n",
        "def toy_problem_inputs(num_devices, batch_size=64,\n",
        "                       train_lengths=(10, 20), eval_lengths=(20,)):\n",
        "  \"\"\"Make Inputs for the toy problem of the language 0w0w for w in [1..127]*.\n",
        "\n",
        "  Args:\n",
        "    num_devices: how many devices to build the inputs for (assert 1 for colab).\n",
        "    batch_size: number of sequences per batch.\n",
        "    train_lengths: candidate lengths for training; w has length // 2 tokens.\n",
        "    eval_lengths: candidate lengths for eval; w has length // 2 tokens.\n",
        "\n",
        "  Returns:\n",
        "    trax.inputs.Inputs\n",
        "  \"\"\"\n",
        "  assert num_devices == 1\n",
        "  def random_minibatches(length_list):\n",
        "    \"\"\"Generate a stream of random mini-batches.\"\"\"\n",
        "    while True:\n",
        "      length = random.choice(length_list)\n",
        "      # randint's `high` is exclusive, so this draws tokens from 1..127.\n",
        "      w = onp.random.randint(low=1, high=VOCAB_SIZE,\n",
        "                            size=(batch_size, length // 2))\n",
        "      zero = onp.zeros([batch_size, 1], onp.int32)\n",
        "      x = onp.concatenate([zero, w, zero, w], axis=1)\n",
        "      yield (x, x)  # In a language model input and output are the same.\n",
        "\n",
        "  return trax_input.Inputs(\n",
        "      train_stream=lambda: random_minibatches(train_lengths),\n",
        "      train_eval_stream=lambda: random_minibatches(train_lengths),\n",
        "      eval_stream=lambda: random_minibatches(eval_lengths),\n",
        "      input_shape=(None,))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eU0mpaf1lRky",
        "colab_type": "code",
        "outputId": "bf94086c-5d97-462b-b565-d4ba5f59b6c4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "inputs = toy_problem_inputs(1)\n",
        "print(next(inputs.train_stream())[0][0])"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[  0  68  91  99 107 115 113 111  17 102  48   0  68  91  99 107 115 113\n",
            " 111  17 102  48]\n"
          ],
          "name": "stdout"
        }
      ]
    },
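    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough sanity check on the accuracy that training reports for this task, the sketch below counts which positions of a `0w0w` sequence can be predicted in principle. It assumes a single fixed length (as in the eval stream) and treats the first copy of `w` as unpredictable random tokens. `predictable_fraction` is a hypothetical helper for illustration, not part of Trax.\n",
        "\n",
        "```python\n",
        "def predictable_fraction(length):\n",
        "    '''Fraction of next-token targets in 0w0w that context determines.'''\n",
        "    n = length // 2      # Number of tokens in w, as in toy_problem_inputs.\n",
        "    total = 2 * n + 2    # Leading 0, w, separator 0, second copy of w.\n",
        "    # Predictable: leading 0, separator 0 (fixed length), second copy of w.\n",
        "    predictable = n + 2\n",
        "    return predictable / total\n",
        "\n",
        "\n",
        "print(predictable_fraction(20))\n",
        "```\n",
        "\n",
        "Under these assumptions, a model that copies perfectly would score roughly 0.55 on sequences built with `length=20`, so accuracy far below 1.0 is expected on this task."
      ]
    },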
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KvNaSWu5g2Vm",
        "colab_type": "text"
      },
      "source": [
        "## Baseline Transformer on Toy Problem"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AGDmtrgcl73M",
        "colab_type": "code",
        "outputId": "4c0f12e9-10ec-4e67-9f15-d2cc7084c083",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 748
        }
      },
      "source": [
        "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M\")\n",
        "output_dir = os.path.expanduser(\"~/trax_lm_%s\" % timestamp)\n",
        "def model(mode):\n",
        "  return trax_models.TransformerLM(\n",
        "      VOCAB_SIZE, feature_depth=128,\n",
        "      feedforward_depth=256, num_layers=3,\n",
        "      num_heads=4, mode=mode)\n",
        "_ = trax.train(model=model,\n",
        "               inputs=toy_problem_inputs,\n",
        "               output_dir=output_dir,\n",
        "               train_steps=3000,\n",
        "               eval_steps=10,\n",
        "               eval_frequency=1000)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step      0: Starting training using 1 devices\n",
            "\n",
            "Step      1: Ran 1 train steps in 36.77 secs\n",
            "Step      1: Total trainable parameters size: 692736\n",
            "Step      1: Evaluation\n",
            "Step      1: train           accuracy |  0.00616714\n",
            "Step      1: train neg_log_perplexity | -5.06836748\n",
            "Step      1: train               loss |  5.06836748\n",
            "Step      1: eval            accuracy |  0.00610795\n",
            "Step      1: eval  neg_log_perplexity | -5.20451212\n",
            "Step      1: eval                loss |  5.20451212\n",
            "Step      1: Finished evaluation\n",
            "\n",
            "Step   1000: Ran 999 train steps in 89.13 secs\n",
            "Step   1000: Evaluation\n",
            "Step   1000: train           accuracy |  0.45719695\n",
            "Step   1000: train neg_log_perplexity | -2.71764731\n",
            "Step   1000: train               loss |  2.71764731\n",
            "Step   1000: eval            accuracy |  0.41278410\n",
            "Step   1000: eval  neg_log_perplexity | -2.94052887\n",
            "Step   1000: eval                loss |  2.94052887\n",
            "Step   1000: Finished evaluation\n",
            "\n",
            "Step   2000: Ran 1000 train steps in 15.61 secs\n",
            "Step   2000: Evaluation\n",
            "Step   2000: train           accuracy |  0.43169984\n",
            "Step   2000: train neg_log_perplexity | -2.82782769\n",
            "Step   2000: train               loss |  2.82782769\n",
            "Step   2000: eval            accuracy |  0.41278410\n",
            "Step   2000: eval  neg_log_perplexity | -2.92255998\n",
            "Step   2000: eval                loss |  2.92255998\n",
            "Step   2000: Finished evaluation\n",
            "\n",
            "Step   3000: Ran 1000 train steps in 15.64 secs\n",
            "Step   3000: Evaluation\n",
            "Step   3000: train           accuracy |  0.45053267\n",
            "Step   3000: train neg_log_perplexity | -2.73254609\n",
            "Step   3000: train               loss |  2.73254609\n",
            "Step   3000: eval            accuracy |  0.41249999\n",
            "Step   3000: eval  neg_log_perplexity | -2.92720962\n",
            "Step   3000: eval                loss |  2.92720962\n",
            "Step   3000: Finished evaluation\n",
            "Step   3000: Training done\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eapBBkRUuho7",
        "colab_type": "text"
      },
      "source": [
        "# Decoding from a Pre-Trained Transformer Language Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H6hVQ3v5iC00",
        "colab_type": "code",
        "outputId": "812949cc-4294-4a42-f55a-c40f65e151f8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 187
        }
      },
      "source": [
        "# Load the model checkpoint.\n",
        "with open(\"model.pkl\", \"rb\") as f:\n",
        "  (params, step, history) = cPickle.load(f, encoding=\"latin1\")\n",
        "\n",
        "# Load the lm1b subword vocab.\n",
        "def clean(x):\n",
        "  # Drop the first character and the last two (closing quote and newline).\n",
        "  return x[1:-2]\n",
        "with open(\"vocab.lm1b.en.32768\", \"r\") as fp:\n",
        "  vocab = list(map(clean, fp.readlines()))\n",
        "vocab_map = {v: idx for idx, v in enumerate(vocab)}\n",
        "\n",
        "list(enumerate(vocab))[:10]"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[(0, '<pad>_'),\n",
              " (1, '<EOS>_'),\n",
              " (2, 'the_'),\n",
              " (3, ' , _'),\n",
              " (4, ' ._'),\n",
              " (5, 'to_'),\n",
              " (6, 'of_'),\n",
              " (7, 'a_'),\n",
              " (8, 'and_'),\n",
              " (9, 'in_')]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W-7s9RXQNIru",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tlm = trax_models.TransformerLM(\n",
        "  dropout=0.1, \n",
        "  feature_depth=512, \n",
        "  feedforward_depth=2048, \n",
        "  max_len=2048, \n",
        "  mode='eval', \n",
        "  num_heads=8, \n",
        "  num_layers=6, \n",
        "  vocab_size=32000)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iLdtplDpdTMr",
        "colab_type": "code",
        "colab": {}
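      },
      "source": [
        "def gumbel_sample(v, temperature=0.8):\n",
        "  \"\"\"Sample an index from log-probabilities v using Gumbel noise.\"\"\"\n",
        "  u = onp.random.uniform(low=1e-9, high=1.0, size=v.shape)\n",
        "  g = -onp.log(-onp.log(u))  # Standard Gumbel noise.\n",
        "  return np.argmax(v + g * temperature)"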
      ],
      "execution_count": 0,
      "outputs": []
    },
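    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The temperature in `gumbel_sample` scales the noise rather than the logits. Since `argmax(v + T * g)` equals `argmax(v / T + g)` for `T > 0`, this should be equivalent to sampling from `softmax(v / T)`. The sketch below checks that numerically in plain NumPy (imported here as `np`, unlike the notebook, where `np` is the Trax backend). `gumbel_sample_many` and `tempered_softmax` are hypothetical helper names, and the three-way distribution is made up for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def gumbel_sample_many(log_probs, temperature, num_samples, rng):\n",
        "    '''Draw num_samples indices via argmax(log_probs + temperature * g).'''\n",
        "    u = rng.uniform(low=1e-9, high=1.0, size=(num_samples,) + log_probs.shape)\n",
        "    g = -np.log(-np.log(u))  # Standard Gumbel noise.\n",
        "    return np.argmax(log_probs + temperature * g, axis=-1)\n",
        "\n",
        "\n",
        "def tempered_softmax(log_probs, temperature):\n",
        "    '''softmax(log_probs / temperature), computed stably.'''\n",
        "    z = log_probs / temperature\n",
        "    p = np.exp(z - z.max())\n",
        "    return p / p.sum()\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "log_probs = np.log(np.array([0.6, 0.3, 0.1]))\n",
        "draws = gumbel_sample_many(log_probs, 0.8, 200000, rng)\n",
        "empirical = np.bincount(draws, minlength=3) / draws.size\n",
        "expected = tempered_softmax(log_probs, 0.8)\n",
        "print(empirical, expected)\n",
        "```\n",
        "\n",
        "The two printed vectors should agree to within sampling error. Temperatures below 1 sharpen the distribution toward the most likely token."
      ]
    },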
    {
      "cell_type": "code",
      "metadata": {
        "id": "IHSbtHzPjW6i",
        "colab_type": "code",
        "outputId": "7a8306b7-6c6b-41ba-c5aa-c1a76d9d8037",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        }
      },
      "source": [
        "prompt = \"Please_\"\n",
        "num_samples = 5\n",
        "max_length = 20\n",
        "for _ in range(num_samples):\n",
        "  enc = [vocab_map[w] for w in prompt.split()]\n",
        "  pos = len(enc)\n",
        "  rng = trax_random.get_prng(0)\n",
        "  data = np.zeros((1, 50), dtype=np.int32)\n",
        "  # The batch has a single row, so write the prompt into row 0.\n",
        "  data = index_update(data, index[0, 0:pos], enc)\n",
        "\n",
        "  while pos < max_length:\n",
        "    tmp = tlm(data, params=params, rng=rng)\n",
        "    next_sym = gumbel_sample(tmp[0, pos])\n",
        "    data = index_update(data, index[0, pos], next_sym)\n",
        "    pos += 1\n",
        "    if int(next_sym) == 1:  # Stop at <EOS>, which has id 1 in the vocab.\n",
        "      break\n",
        "\n",
        "  print(\"\".join([vocab[idx] for idx in onp.array(data)[0, 0:pos]]))"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Please_write_to_him_to_tell_him_about_the_Wallace_and_Gromit_films_. _and_to_give_him_this_\n",
            "Please_do_not_turn_to_making_sure_your_children_are_already_in_school_or_that_you_have_school_ ._<EOS>_\n",
            "Please_read_the_full_prospectus_to_see_if_the_proposed_transaction_may_be_accurate_ ._<EOS>_\n",
            "Please_note_that_the_new_policy_has_been_strengthened_by_the_fact_that_Britney_Spears_ ' _mother_ , _Janet_Jackson_\n",
            "Please_ , _please_aim_at_your_brother_ , _if_you_want_to_ ._<EOS>_\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ym8otS7HpUIO",
        "colab_type": "text"
      },
      "source": [
        "# Transformer from Scratch\n",
        "\n",
        "Here we re-implement multi-headed self-attention and a Transformer language model from scratch using only a few simple linear primitives from Trax.\n",
        "\n",
        "Note in particular the commented modifications in the core __DotProductAttention__ function as an example of how easy it is to modify layers and models for research using Trax."
      ]
    },
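    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Multi-headed attention reshapes each `(nbatch, seqlen, feature_depth)` array into `num_heads` slices of depth `feature_depth // num_heads`, attends within each slice, and merges the slices back. A minimal plain-NumPy sketch of that reshaping follows; `split_heads` and `join_heads` are illustrative stand-ins, and the shapes are arbitrary.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def split_heads(x, num_heads):\n",
        "    '''(nbatch, seqlen, feature_depth) -> (nbatch, num_heads, seqlen, head_depth).'''\n",
        "    nbatch, seqlen, feature_depth = x.shape\n",
        "    assert feature_depth % num_heads == 0\n",
        "    head_depth = feature_depth // num_heads\n",
        "    x = x.reshape(nbatch, seqlen, num_heads, head_depth)\n",
        "    return x.transpose(0, 2, 1, 3)\n",
        "\n",
        "\n",
        "def join_heads(x):\n",
        "    '''(nbatch, num_heads, seqlen, head_depth) -> (nbatch, seqlen, feature_depth).'''\n",
        "    nbatch, num_heads, seqlen, head_depth = x.shape\n",
        "    return x.transpose(0, 2, 1, 3).reshape(nbatch, seqlen, num_heads * head_depth)\n",
        "\n",
        "\n",
        "x = np.arange(2 * 5 * 8, dtype=np.float32).reshape(2, 5, 8)\n",
        "heads = split_heads(x, num_heads=4)\n",
        "print(heads.shape)\n",
        "print(np.array_equal(join_heads(heads), x))\n",
        "```\n",
        "\n",
        "Joining is the exact inverse of splitting, so no information is lost. The heads only change which features are compared with each other in the dot product."
      ]
    },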
    {
      "cell_type": "code",
      "metadata": {
        "id": "uw-GIdm2p_4X",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def DotProductAttention(query, key, value, mask, dropout, mode, rng, hard_k=4):\n",
        "  \"\"\"Core dot product self-attention.\n",
        "  Args:\n",
        "    query: array of representations\n",
        "    key: array of representations\n",
        "    value: array of representations\n",
        "    mask: attention-mask, gates attention\n",
        "    dropout: float: dropout rate\n",
        "    mode: 'eval' or 'train': whether to use dropout\n",
        "    rng: JAX PRNGKey: subkey for disposable use\n",
        "    hard_k: int: if > 0, keep only weights above the k-th largest one\n",
        "  Returns:\n",
        "    Self attention for q, k, v arrays.\n",
        "  \"\"\"\n",
        "  depth = np.shape(query)[-1]\n",
        "  dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)\n",
        "  if mask is not None:\n",
        "    dots = np.where(mask, dots, -1e9)\n",
        "  # Softmax.\n",
        "  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))\n",
        "  # ----------------------------------------------------------------------\n",
        "  # As an example of a simple research modification, we modify the typical\n",
        "  # dot-product attention mechanism with top-k \"hard attention\":\n",
        "  # ----------------------------------------------------------------------\n",
        "  if hard_k > 0:\n",
        "    top_k = np.sort(dots)[..., -hard_k]  # Get the k-th largest weight.\n",
        "    dots -= top_k[..., np.newaxis]  # Shift so it and lower weights are <= 0.\n",
        "    dots = np.maximum(dots, 0)  # Zero out everything that is not above it.\n",
        "    dots /= np.sum(dots, axis=-1, keepdims=True)  # Re-normalize.\n",
        "  # ----------------------------------------------------------------------\n",
        "  if dropout is not None and dropout >= 1.0:\n",
        "    raise ValueError('Dropout rates must be lower than 1.')\n",
        "  if dropout is not None and dropout > 0.0 and mode == 'train':\n",
        "    keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)\n",
        "    dots = np.where(keep, dots / (1.0 - dropout), 0)\n",
        "  out = np.matmul(dots, value)\n",
        "  # Uncomment to see an example Trax stack trace to this point:\n",
        "  # ----------------------------------------------------------------------\n",
        "  # raise ValueError(\"err\")\n",
        "  # ----------------------------------------------------------------------\n",
        "  return out\n",
        "\n",
        "\n",
        "def _multihead_attention_output_shape(  # pylint: disable=invalid-name\n",
        "    input_shapes, **unused_kwargs):\n",
        "  \"\"\"Helper: calculate multihead attention output shape.\"\"\"\n",
        "  q_shape = input_shapes[0][0]  # Inputs are ((q, k, v), mask).\n",
        "  mask_shape = input_shapes[1]\n",
        "  return q_shape, mask_shape\n",
        "\n",
        "\n",
        "@tl.layer(output_shape=_multihead_attention_output_shape)\n",
        "def PureMultiHeadedAttention(x, params, num_heads=8, dropout=0.0,\n",
        "                             mode='train', **kwargs):\n",
        "  \"\"\"Pure transformer-style multi-headed attention.\n",
        "  Args:\n",
        "    x: inputs ((q, k, v), mask)\n",
        "    params: parameters (none)\n",
        "    num_heads: int: number of attention heads\n",
        "    dropout: float: dropout rate\n",
        "    mode: str: 'train' or 'eval'\n",
        "    **kwargs: other arguments including the rng\n",
        "  Returns:\n",
        "    Pure Multi-headed attention result, and the mask.\n",
        "  \"\"\"\n",
        "  del params\n",
        "  rng = kwargs.get('rng', None)\n",
        "  (q, k, v), mask = x\n",
        "  feature_depth = q.shape[-1]\n",
        "  assert feature_depth % num_heads == 0\n",
        "  head_depth = feature_depth // num_heads\n",
        "  nbatch = np.shape(q)[0]\n",
        "  # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth\n",
        "  def SplitHeads(x):\n",
        "    return np.transpose(\n",
        "        np.reshape(x, (nbatch, -1, num_heads, head_depth)), (0, 2, 1, 3))\n",
        "  # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth\n",
        "  def JoinHeads(x):  # pylint: disable=invalid-name\n",
        "    return np.reshape(\n",
        "        np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, num_heads*head_depth))\n",
        "  # Split heads, dot-product attention, rejoin heads.\n",
        "  res = JoinHeads(\n",
        "      DotProductAttention(\n",
        "          SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,\n",
        "          dropout=dropout, mode=mode, rng=rng))\n",
        "  return res, mask  # Keep the mask.\n",
        "\n",
        "\n",
        "def MultiHeadedAttentionQKV(\n",
        "    feature_depth, num_heads=8, dropout=0.0, mode='train'):\n",
        "  \"\"\"Transformer-style multi-headed attention.\n",
        "  Accepts inputs of the form (q, k, v), mask.\n",
        "  Args:\n",
        "    feature_depth: int:  depth of embedding\n",
        "    num_heads: int: number of attention heads\n",
        "    dropout: float: dropout rate\n",
        "    mode: str: 'train' or 'eval'\n",
        "  Returns:\n",
        "    Multi-headed self-attention result and the mask.\n",
        "  \"\"\"\n",
        "  return tl.Serial(\n",
        "      tl.Parallel(\n",
        "          tl.Parallel(\n",
        "              tl.Dense(feature_depth),\n",
        "              tl.Dense(feature_depth),\n",
        "              tl.Dense(feature_depth),\n",
        "          ),\n",
        "          tl.Copy()\n",
        "      ),\n",
        "      PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter\n",
        "          feature_depth=feature_depth, num_heads=num_heads,\n",
        "          dropout=dropout, mode=mode),\n",
        "      tl.Parallel(tl.Dense(feature_depth), tl.Copy())\n",
        "  )\n",
        "\n",
        "\n",
        "def MultiHeadedAttention(\n",
        "    feature_depth, num_heads=8, dropout=0.0, mode='train'):\n",
        "  \"\"\"Transformer-style multi-headed attention.\n",
        "  Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.\n",
        "  Args:\n",
        "    feature_depth: int:  depth of embedding\n",
        "    num_heads: int: number of attention heads\n",
        "    dropout: float: dropout rate\n",
        "    mode: str: 'train' or 'eval'\n",
        "  Returns:\n",
        "    Multi-headed self-attention layer.\n",
        "  \"\"\"\n",
        "  return tl.Serial(\n",
        "      tl.Parallel(\n",
        "          # q = k = v = first input\n",
        "          tl.Branch(\n",
        "              tl.Copy(), tl.Copy(), tl.Copy()),\n",
        "          tl.Copy()  # pass the mask\n",
        "      ),\n",
        "      MultiHeadedAttentionQKV(  # pylint: disable=no-value-for-parameter\n",
        "          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),\n",
        "  )"
      ],
      "execution_count": 0,
      "outputs": []
    },
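    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see what the `hard_k` branch of `DotProductAttention` does to a single row of attention weights, here is a small plain-NumPy sketch of the same arithmetic. `hard_top_k` is a hypothetical stand-alone helper and the weights are made up.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def hard_top_k(weights, hard_k):\n",
        "    '''Same arithmetic as the hard_k branch, applied to the last axis.'''\n",
        "    kth = np.sort(weights, axis=-1)[..., -hard_k]  # k-th largest weight.\n",
        "    shifted = np.maximum(weights - kth[..., np.newaxis], 0)\n",
        "    return shifted / np.sum(shifted, axis=-1, keepdims=True)\n",
        "\n",
        "\n",
        "weights = np.array([0.1, 0.2, 0.3, 0.4])\n",
        "print(hard_top_k(weights, hard_k=2))\n",
        "print(hard_top_k(weights, hard_k=3))\n",
        "```\n",
        "\n",
        "Because the k-th largest weight is itself shifted to zero, at most `hard_k - 1` positions keep nonzero weight: with `hard_k=2` all of the mass lands on the largest entry. With `hard_k=1` every entry would become zero and the re-normalization would divide by zero, so values of at least 2 are the useful range for this formulation."
      ]
    },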
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ge42t7VZl-d2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def ResidualFeedForward(feature_depth,\n",
        "                        feedforward_depth,\n",
        "                        dropout,\n",
        "                        mode):\n",
        "  \"\"\"Residual feed-forward layer with normalization at start.\"\"\"\n",
        "  return tl.Residual(\n",
        "      tl.LayerNorm(),\n",
        "      tl.Dense(feedforward_depth),\n",
        "      tl.Relu(),\n",
        "      tl.Dropout(rate=dropout, mode=mode),\n",
        "      tl.Dense(feature_depth),\n",
        "      tl.Dropout(rate=dropout, mode=mode)\n",
        "  )\n",
        "\n",
        "\n",
        "def DecoderLayer(feature_depth,\n",
        "                 feedforward_depth,\n",
        "                 num_heads,\n",
        "                 dropout,\n",
        "                 mode):\n",
        "  \"\"\"Transformer decoder layer.\n",
        "  Args:\n",
        "    feature_depth: int:  depth of embedding\n",
        "    feedforward_depth: int: depth of feed-forward layer\n",
        "    num_heads: int: number of attention heads\n",
        "    dropout: float: dropout rate (how much to drop out)\n",
        "    mode: str: 'train' or 'eval'\n",
        "  Returns:\n",
        "    the layer.\n",
        "  \"\"\"\n",
        "  return tl.Serial(\n",
        "      tl.Residual(  # Self-attention block.\n",
        "          tl.LayerNorm(),\n",
        "          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.\n",
        "          # We replace the \"stock\" self-attention layer with the one defined\n",
        "          # above:\n",
        "          # tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,\n",
        "          #                         dropout=dropout, mode=mode),\n",
        "          MultiHeadedAttention(feature_depth, num_heads=num_heads,\n",
        "                               dropout=dropout, mode=mode),\n",
        "          tl.Select(0),  # Drop the mask.\n",
        "          tl.Dropout(rate=dropout, mode=mode)\n",
        "      ),\n",
        "      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)\n",
        "  )\n",
        "\n",
        "\n",
        "def TransformerLM(vocab_size,\n",
        "                  feature_depth=512,\n",
        "                  feedforward_depth=2048,\n",
        "                  num_layers=6,\n",
        "                  num_heads=8,\n",
        "                  dropout=0.1,\n",
        "                  max_len=2048,\n",
        "                  mode='train'):\n",
        "  \"\"\"Transformer language model (only uses the decoder part of Transformer).\n",
        "  Args:\n",
        "    vocab_size: int: vocab size\n",
        "    feature_depth: int:  depth of embedding\n",
        "    feedforward_depth: int: depth of feed-forward layer\n",
        "    num_layers: int: number of decoder layers\n",
        "    num_heads: int: number of attention heads\n",
        "    dropout: float: dropout rate (how much to drop out)\n",
        "    max_len: int: maximum symbol length for positional encoding\n",
        "    mode: str: 'train' or 'eval'\n",
        "  Returns:\n",
        "    the layer.\n",
        "  \"\"\"\n",
        "  return tl.Serial(\n",
        "      tl.ShiftRight(),\n",
        "      tl.Embedding(feature_depth, vocab_size),\n",
        "      tl.Dropout(rate=dropout, mode=mode),\n",
        "      tl.PositionalEncoding(max_len=max_len),\n",
        "      tl.Serial(*[DecoderLayer(feature_depth, feedforward_depth, num_heads,\n",
        "                               dropout, mode)\n",
        "                  for _ in range(num_layers)]),\n",
        "      tl.LayerNorm(),\n",
        "      tl.Dense(vocab_size),\n",
        "      tl.LogSoftmax()\n",
        "  )"
      ],
      "execution_count": 0,
      "outputs": []
    },
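    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The decoder layer relies on a causal mask so that each position attends only to itself and earlier positions. The following plain-NumPy sketch illustrates that idea for a single head and a single sequence. It is not the Trax implementation: `causal_attention` is a hypothetical helper, and the lower-triangular mask is built directly with `np.tril` instead of `tl.CausalMask`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def causal_attention(q, k, v):\n",
        "    '''Scaled dot-product attention with a lower-triangular (causal) mask.'''\n",
        "    seqlen, depth = q.shape\n",
        "    dots = q @ k.T / np.sqrt(depth)\n",
        "    mask = np.tril(np.ones((seqlen, seqlen), dtype=bool))\n",
        "    dots = np.where(mask, dots, -1e9)  # Future positions get a huge penalty.\n",
        "    dots = dots - dots.max(axis=-1, keepdims=True)  # Stabilize the softmax.\n",
        "    weights = np.exp(dots)\n",
        "    weights = weights / weights.sum(axis=-1, keepdims=True)\n",
        "    return weights @ v, weights\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "x = rng.normal(size=(5, 8))\n",
        "out, weights = causal_attention(x, x, x)\n",
        "print(np.round(weights, 2))\n",
        "```\n",
        "\n",
        "The printed weight matrix is lower triangular, and the first output row equals the first value row because position 0 can only attend to itself."
      ]
    },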
    {
      "cell_type": "code",
      "metadata": {
        "id": "WZxnwjAEqYDh",
        "colab_type": "code",
        "outputId": "f90e965d-2625-4e56-9038-65c087639051",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 748
        }
      },
      "source": [
        "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M\")\n",
        "output_dir = os.path.expanduser(\"~/trax_lm_%s\" % timestamp)\n",
        "def new_model(mode):\n",
        "  return TransformerLM(\n",
        "      VOCAB_SIZE, feature_depth=128,\n",
        "      feedforward_depth=256, num_layers=3,\n",
        "      num_heads=4, mode=mode)\n",
        "_ = trax.train(model=new_model,\n",
        "               inputs=toy_problem_inputs,\n",
        "               output_dir=output_dir,\n",
        "               train_steps=3000,\n",
        "               eval_steps=10,\n",
        "               eval_frequency=1000)"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step      0: Starting training using 1 devices\n",
            "\n",
            "Step      1: Ran 1 train steps in 42.29 secs\n",
            "Step      1: Total trainable parameters size: 692736\n",
            "Step      1: Evaluation\n",
            "Step      1: train           accuracy |  0.00686553\n",
            "Step      1: train neg_log_perplexity | -5.42891455\n",
            "Step      1: train               loss |  5.42891455\n",
            "Step      1: eval            accuracy |  0.00809659\n",
            "Step      1: eval  neg_log_perplexity | -5.39403439\n",
            "Step      1: eval                loss |  5.39403439\n",
            "Step      1: Finished evaluation\n",
            "\n",
            "Step   1000: Ran 999 train steps in 109.64 secs\n",
            "Step   1000: Evaluation\n",
            "Step   1000: train           accuracy |  0.12875238\n",
            "Step   1000: train neg_log_perplexity | -4.29979420\n",
            "Step   1000: train               loss |  4.29979420\n",
            "Step   1000: eval            accuracy |  0.09928977\n",
            "Step   1000: eval  neg_log_perplexity | -4.45948172\n",
            "Step   1000: eval                loss |  4.45948172\n",
            "Step   1000: Finished evaluation\n",
            "\n",
            "Step   2000: Ran 1000 train steps in 16.89 secs\n",
            "Step   2000: Evaluation\n",
            "Step   2000: train           accuracy |  0.53104877\n",
            "Step   2000: train neg_log_perplexity | -2.33383632\n",
            "Step   2000: train               loss |  2.33383632\n",
            "Step   2000: eval            accuracy |  0.54900569\n",
            "Step   2000: eval  neg_log_perplexity | -2.24813342\n",
            "Step   2000: eval                loss |  2.24813342\n",
            "Step   2000: Finished evaluation\n",
            "\n",
            "Step   3000: Ran 1000 train steps in 16.91 secs\n",
            "Step   3000: Evaluation\n",
            "Step   3000: train           accuracy |  0.56715208\n",
            "Step   3000: train neg_log_perplexity | -2.15219927\n",
            "Step   3000: train               loss |  2.15219927\n",
            "Step   3000: eval            accuracy |  0.54928976\n",
            "Step   3000: eval  neg_log_perplexity | -2.25436211\n",
            "Step   3000: eval                loss |  2.25436211\n",
            "Step   3000: Finished evaluation\n",
            "Step   3000: Training done\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}